{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data loading pipeline examples\n",
    "\n",
    "The purpose of this notebook is to illustrate reading Nifti files and test speed of different methods.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/acceleration/transform_speed.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 0.4.0+35.g6adbcde\n",
      "Numpy version: 1.19.5\n",
      "Pytorch version: 1.7.1\n",
      "MONAI flags: HAS_EXT = False, USE_COMPILED = False\n",
      "MONAI rev id: 6adbcdee45c16f18f5b713575af3410437177311\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: 0.4.2\n",
      "Nibabel version: 3.2.1\n",
      "scikit-image version: 0.18.1\n",
      "Pillow version: 7.0.0\n",
      "Tensorboard version: 2.2.0\n",
      "gdown version: 3.12.2\n",
      "TorchVision version: 0.8.2\n",
      "ITK version: 5.1.2\n",
      "tqdm version: 4.51.0\n",
      "lmdb version: 1.0.0\n",
      "psutil version: 5.8.0\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Copyright 2020 MONAI Consortium\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "\n",
    "import glob\n",
    "import os\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "try:\n",
    "    torch.multiprocessing.set_start_method(\"spawn\")\n",
    "except RuntimeError:\n",
    "    pass\n",
    "\n",
    "\n",
    "from monai.config import print_config\n",
    "from monai.data import ArrayDataset, create_test_image_3d\n",
    "from monai.transforms import (\n",
    "    AddChannel,\n",
    "    Compose,\n",
    "    LoadImage,\n",
    "    RandAffine,\n",
    "    RandSpatialCrop,\n",
    "    Rotate,\n",
    "    ScaleIntensity,\n",
    "    EnsureType,\n",
    ")\n",
    "from monai.utils import first\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/data/medical\n"
     ]
    }
   ],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory:\n",
    "    directory = os.path.join(directory, \"transform_speed\")\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0. Preparing input data (nifti images)\n",
    "\n",
    "Create a number of test Nifti files, 3d single channel images with spatial size (256, 256, 256) voxels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(5):\n",
    "    im, seg = create_test_image_3d(256, 256, 256)\n",
    "\n",
    "    n = nib.Nifti1Image(im, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"im{i}.nii.gz\"))\n",
    "\n",
    "    n = nib.Nifti1Image(seg, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"seg{i}.nii.gz\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare list of image names and segmentation names\n",
    "images = sorted(glob.glob(os.path.join(root_dir, \"im*.nii.gz\")))\n",
    "segs = sorted(glob.glob(os.path.join(root_dir, \"seg*.nii.gz\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Test image loading with minimal preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 1, 256, 256, 256]) torch.Size([3, 1, 256, 256, 256])\n"
     ]
    }
   ],
   "source": [
    "imtrans = Compose([LoadImage(image_only=True), AddChannel(), EnsureType()])\n",
    "\n",
    "segtrans = Compose([LoadImage(image_only=True), AddChannel(), EnsureType()])\n",
    "\n",
    "ds = ArrayDataset(images, imtrans, segs, segtrans)\n",
    "loader = torch.utils.data.DataLoader(ds, batch_size=3, num_workers=8)\n",
    "\n",
    "im, seg = first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 23.3 ms, sys: 133 ms, total: 156 ms\n",
      "Wall time: 8.6 s\n"
     ]
    }
   ],
   "source": [
    "%time data = next(iter(loader))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Test image-patch loading with CPU multi-processing:\n",
    "\n",
    "- rotate (256, 256, 256)-voxel in the plane axes=(1, 2)\n",
    "- extract random (64, 64, 64) patches\n",
    "- implemented in MONAI using ` scipy.ndimage.rotate`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 1, 64, 64, 64]) torch.Size([3, 1, 64, 64, 64])\n"
     ]
    }
   ],
   "source": [
    "images = sorted(glob.glob(os.path.join(root_dir, \"im*.nii.gz\")))\n",
    "segs = sorted(glob.glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n",
    "\n",
    "imtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True),\n",
    "        ScaleIntensity(),\n",
    "        AddChannel(),\n",
    "        Rotate(angle=np.pi / 4),\n",
    "        RandSpatialCrop((64, 64, 64), random_size=False),\n",
    "        EnsureType(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "segtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True),\n",
    "        AddChannel(),\n",
    "        Rotate(angle=np.pi / 4),\n",
    "        RandSpatialCrop((64, 64, 64), random_size=False),\n",
    "        EnsureType(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "ds = ArrayDataset(images, imtrans, segs, segtrans)\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    ds, batch_size=3, num_workers=8, pin_memory=torch.cuda.is_available()\n",
    ")\n",
    "\n",
    "im, seg = first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 20.4 ms, sys: 1.07 s, total: 1.09 s\n",
      "Wall time: 22.6 s\n"
     ]
    }
   ],
   "source": [
    "%time data = next(iter(loader))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(the above results were based on Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Test image-patch loading with preprocessing on GPU:\n",
    "\n",
    "- random rotate (256, 256, 256)-voxel in the plane axes=(1, 2)\n",
    "- extract random (64, 64, 64) patches\n",
    "- implemented in MONAI using native pytorch resampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 1, 64, 64, 64]) torch.Size([3, 1, 64, 64, 64])\n"
     ]
    }
   ],
   "source": [
    "images = sorted(glob.glob(os.path.join(root_dir, \"im*.nii.gz\")))\n",
    "segs = sorted(glob.glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n",
    "\n",
    "# same parameter with different interpolation mode for image and segmentation\n",
    "rand_affine_img = RandAffine(\n",
    "    prob=1.0,\n",
    "    rotate_range=np.pi / 4,\n",
    "    translate_range=(96, 96, 96),\n",
    "    spatial_size=(64, 64, 64),\n",
    "    mode=\"bilinear\",\n",
    "    as_tensor_output=True,\n",
    "    device=torch.device(\"cuda:0\"),\n",
    ")\n",
    "rand_affine_seg = RandAffine(\n",
    "    prob=1.0,\n",
    "    rotate_range=np.pi / 4,\n",
    "    translate_range=(96, 96, 96),\n",
    "    spatial_size=(64, 64, 64),\n",
    "    mode=\"nearest\",\n",
    "    as_tensor_output=True,\n",
    "    device=torch.device(\"cuda:0\"),\n",
    ")\n",
    "\n",
    "imtrans = Compose(\n",
    "    [LoadImage(image_only=True), ScaleIntensity(),\n",
    "     AddChannel(), rand_affine_img, EnsureType()]\n",
    ")\n",
    "\n",
    "segtrans = Compose([LoadImage(image_only=True),\n",
    "                    AddChannel(), rand_affine_seg, EnsureType()])\n",
    "\n",
    "ds = ArrayDataset(images, imtrans, segs, segtrans)\n",
    "loader = torch.utils.data.DataLoader(ds, batch_size=3, num_workers=0)\n",
    "\n",
    "im, seg = first(loader)\n",
    "\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.25 s, sys: 1.06 s, total: 4.31 s\n",
      "Wall time: 4.31 s\n"
     ]
    }
   ],
   "source": [
    "%time data = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tesla V100-SXM3-32GB\n",
      "|===========================================================================|\n",
      "|                  PyTorch CUDA memory summary, device ID 0                 |\n",
      "|---------------------------------------------------------------------------|\n",
      "|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |\n",
      "|===========================================================================|\n",
      "|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Allocated memory      |   12288 KB |   88064 KB |    1188 MB |    1176 MB |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Active memory         |   12288 KB |   88064 KB |    1188 MB |    1176 MB |\n",
      "|---------------------------------------------------------------------------|\n",
      "| GPU reserved memory   |  159744 KB |  159744 KB |  159744 KB |       0 B  |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Non-releasable memory |    8192 KB |   77823 KB |     833 MB |     825 MB |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Allocations           |       4    |      12    |     208    |     204    |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Active allocs         |       4    |      12    |     208    |     204    |\n",
      "|---------------------------------------------------------------------------|\n",
      "| GPU reserved segments |       7    |       7    |       7    |       0    |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Non-releasable allocs |       1    |       4    |     133    |     132    |\n",
      "|===========================================================================|\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(torch.cuda.get_device_name(0))\n",
    "print(torch.cuda.memory_summary(0, abbreviated=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove directory if a temporary was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
